TaskContext
TaskContext allows a task to access contextual information about itself as well as register task listeners.
Using TaskContext, you can access local properties that were set by the driver. You can also access task metrics.
You can access the active TaskContext instance using the TaskContext.get method.
TaskContext belongs to the org.apache.spark package.
import org.apache.spark.TaskContext
Note: TaskContext is serializable.
Contextual Information
- stageId is the id of the stage the task belongs to.
- partitionId is the id of the partition computed by the task.
- attemptNumber denotes how many times the task has been attempted before (0 for the first attempt).
- taskAttemptId is the id of this task attempt, unique across all task attempts within the same SparkContext.
- isCompleted returns true when the task has completed.
- isInterrupted returns true when the task has been killed.
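Each attribute is available through a method of the same name, e.g. partitionId() for the partition id. The TaskContext object also provides TaskContext.getPartitionId(), which returns the partition id of the active task, or 0 when there is no active TaskContext.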
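The attributes above can be read from inside running tasks. A minimal sketch, assuming a spark-shell session where sc is already defined; the tuple layout is an arbitrary choice for illustration.

```scala
import org.apache.spark.TaskContext

// Three elements in three partitions: each task handles exactly one element.
val info = sc.range(0, 3, numSlices = 3).map { _ =>
  val tc = TaskContext.get()
  // (partitionId, stageId, attemptNumber, isCompleted, isInterrupted)
  (tc.partitionId(), tc.stageId(), tc.attemptNumber(), tc.isCompleted(), tc.isInterrupted())
}.collect()

info.foreach(println)
```

While a task is still running, isCompleted() and isInterrupted() are false, and attemptNumber() stays 0 as long as no attempt has failed.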
Registering Task Listeners
Using a TaskContext, you can register two kinds of task listeners: listeners executed on task completion regardless of the final state, and listeners executed on task failure only.
addTaskCompletionListener Method
addTaskCompletionListener(listener: TaskCompletionListener): TaskContext
addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext
The addTaskCompletionListener methods register a TaskCompletionListener to be executed on task completion.
Note: The listener is executed regardless of the final state of the task: success, failure, or cancellation.
val rdd = sc.range(0, 5, numSlices = 1)
import org.apache.spark.TaskContext
// Prints the contextual information of the task the listener is registered with
val printTaskInfo = (tc: TaskContext) => {
  val msg = s"""|-------------------
                |partitionId: ${tc.partitionId}
                |stageId: ${tc.stageId}
                |attemptNum: ${tc.attemptNumber}
                |taskAttemptId: ${tc.taskAttemptId}
                |-------------------""".stripMargin
  // println runs on the executor; in local mode the output shows up in the shell
  println(msg)
}
rdd.foreachPartition { _ =>
  val tc = TaskContext.get
  // The listener runs once this task finishes
  tc.addTaskCompletionListener(printTaskInfo)
}
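Because completion listeners run however the task ends, they are a natural place to release per-partition resources. A minimal sketch, assuming spark-shell's sc; cleanup is a hypothetical stand-in for closing a connection or file and only increments a LongAccumulator (listenersRun, also a name chosen for this example) so the driver can count how many listeners ran.

```scala
import org.apache.spark.TaskContext

// Counts completion callbacks; read on the driver after the job finishes.
val listenersRun = sc.longAccumulator("listenersRun")

// Hypothetical cleanup: real code would close a resource opened for the partition.
// The explicit function type selects the (TaskContext) => ... overload.
val cleanup: TaskContext => Unit = _ => listenersRun.add(1)

// Two partitions, so two tasks, each registering one listener.
val total = sc.range(0, 4, numSlices = 2).mapPartitions { it =>
  TaskContext.get().addTaskCompletionListener(cleanup)
  it
}.count()

println(s"records: $total, listeners run: ${listenersRun.sum}")
```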
addTaskFailureListener Method
addTaskFailureListener(listener: TaskFailureListener): TaskContext
addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext
The addTaskFailureListener methods register a TaskFailureListener to be executed on task failure only. Since a failed task can be re-attempted and each attempt registers its own listener, the listener code can run once per failed attempt.
val rdd = sc.range(0, 2, numSlices = 2)
import org.apache.spark.TaskContext
val printTaskErrorInfo = (tc: TaskContext, error: Throwable) => {
  val msg = s"""|-------------------
                |partitionId: ${tc.partitionId}
                |stageId: ${tc.stageId}
                |attemptNum: ${tc.attemptNumber}
                |taskAttemptId: ${tc.taskAttemptId}
                |error: ${error.toString}
                |-------------------""".stripMargin
  println(msg)
}
val throwExceptionForOddNumber = (n: Long) => {
  if (n % 2 == 1) {
    throw new Exception(s"No way it will pass for odd number: $n")
  }
}
// This does NOT trigger the failure listener: foreachPartition ignores the
// partition iterator, so the lazy map never runs and nothing is thrown.
rdd.map(throwExceptionForOddNumber).foreachPartition { _ =>
  val tc = TaskContext.get
  tc.addTaskFailureListener(printTaskErrorInfo)
}
// Listener registration matters: register the listener before the iterator is
// consumed. count consumes it, the map throws for the odd number 1 and the
// listener runs for the failing task; the job then fails with a SparkException.
rdd.mapPartitions { (it: Iterator[Long]) =>
  val tc = TaskContext.get
  tc.addTaskFailureListener(printTaskErrorInfo)
  it
}.map(throwExceptionForOddNumber).count
Accessing Local Properties — getLocalProperty Method
getLocalProperty(key: String): String
Use the getLocalProperty method to access local properties that were set on the driver using SparkContext.setLocalProperty before the job was submitted.
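A minimal sketch of the round trip, assuming spark-shell's sc; the key demo.requestedBy is hypothetical and chosen only for this example.

```scala
import org.apache.spark.TaskContext

// "demo.requestedBy" is a hypothetical key used only for this example.
sc.setLocalProperty("demo.requestedBy", "alice")

// Local properties are captured when the job is submitted and shipped with its tasks.
val seenByTasks = sc.range(0, 2, numSlices = 2)
  .map(_ => TaskContext.get().getLocalProperty("demo.requestedBy"))
  .collect()

// Setting a local property to null removes it for later jobs from this thread.
sc.setLocalProperty("demo.requestedBy", null)
```

Local properties are per driver thread, so jobs submitted from other threads do not see demo.requestedBy unless those threads set it too.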
Task Metrics
taskMetrics(): TaskMetrics
The taskMetrics method is part of the Developer API and gives access to the TaskMetrics instance of the task.
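A minimal sketch that reads a few metrics from inside a task, assuming spark-shell's sc and that TaskMetrics exposes memoryBytesSpilled and diskBytesSpilled (as in Spark 2.x and 3.x; being a Developer API, accessors may change between versions). Metrics read while the task runs are not final; some, such as executorRunTime, are only filled in after the task finishes.

```scala
import org.apache.spark.TaskContext

val perTask = sc.range(0, 1000, numSlices = 2).mapPartitions { it =>
  // Consume the partition first so the metrics reflect the work done
  val records = it.size
  val tc = TaskContext.get()
  val metrics = tc.taskMetrics()
  Iterator((tc.partitionId(), records, metrics.memoryBytesSpilled, metrics.diskBytesSpilled))
}.collect()

perTask.foreach(println)
```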
getMetricsSources Method
getMetricsSources(sourceName: String): Seq[Source]
getMetricsSources returns all the metrics sources with the given sourceName that are registered with the metrics system of the executor running the task.
Accessing Active TaskContext — get Method
get(): TaskContext
The get method returns the TaskContext of the currently running task (a TaskContextImpl object), or null when called outside a task, e.g. on the driver. A task thread has at most one active TaskContext, which the task can use to access contextual information about itself.
val rdd = sc.range(0, 3, numSlices = 3)
rdd.partitions.size  // Int = 3
import org.apache.spark.TaskContext
rdd.foreach { _ =>
  val tc = TaskContext.get
  val msg = s"""|-------------------
                |partitionId: ${tc.partitionId}
                |stageId: ${tc.stageId}
                |attemptNum: ${tc.attemptNumber}
                |taskAttemptId: ${tc.taskAttemptId}
                |-------------------""".stripMargin
  println(msg)
}
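Since get returns null outside a task, wrapping it in Option is a defensive pattern for code that may run on either the driver or an executor. A minimal sketch, assuming spark-shell's sc and a regular (non-barrier) stage, where the context class is expected to be TaskContextImpl.

```scala
import org.apache.spark.TaskContext

// The driver thread runs no task, so TaskContext.get() returns null there.
val onDriver = Option(TaskContext.get())

// Inside tasks, get() returns the active context of the task's thread.
val contextClasses = sc.range(0, 2, numSlices = 2)
  .map(_ => Option(TaskContext.get()).map(_.getClass.getName).getOrElse("<none>"))
  .collect()

println(s"driver: $onDriver, tasks: ${contextClasses.mkString(", ")}")
```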
Note: The TaskContext object uses a ThreadLocal to keep the active TaskContext thread-local, i.e. to associate it with the thread that runs the task.
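The thread-local pattern can be sketched without Spark. MiniTaskContext and ActiveContext are hypothetical names; Spark's TaskContext object keeps a ThreadLocal[TaskContext] in a similar way, but this is an illustration, not its code. It runs as-is in the Scala REPL or spark-shell.

```scala
// Hypothetical stand-in for TaskContext: only carries a partition id.
final case class MiniTaskContext(partitionId: Int)

// Hypothetical holder mimicking the ThreadLocal pattern behind TaskContext.get.
object ActiveContext {
  private val current = new ThreadLocal[MiniTaskContext]
  def get(): MiniTaskContext = current.get() // null when nothing was set on this thread
  def set(tc: MiniTaskContext): Unit = current.set(tc)
  def unset(): Unit = current.remove()
}

// One "task" thread per partition; each records what ActiveContext.get() returns.
val seen = Array.fill(3)(-1)
val threads = (0 until 3).map { p =>
  new Thread(new Runnable {
    def run(): Unit = {
      ActiveContext.set(MiniTaskContext(p))
      try { seen(p) = ActiveContext.get().partitionId }
      finally { ActiveContext.unset() }
    }
  })
}
threads.foreach(_.start())
threads.foreach(_.join()) // join makes the writes to seen visible here

// The current thread never called set, so it sees no context.
val currentThreadContext = ActiveContext.get()
```

Each thread only ever sees the context it set itself, which is why TaskContext.get can return the right context for concurrently running tasks in the same executor JVM.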
|
TaskContextImpl
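TaskContextImpl is the main implementation of the TaskContext abstract class. Since Spark 2.4, tasks in barrier stages use BarrierTaskContext, which wraps a TaskContextImpl.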
Caution: FIXME
- stage
- partition
- task attempt
- attempt number
- runningLocally = false
Caution: FIXME Where and how is TaskMemoryManager used?
Creating TaskContextImpl Instance
Caution: FIXME
markInterrupted
Caution: FIXME